{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.6.1\n",
      "IPython 6.0.0\n",
      "\n",
      "tensorflow 1.2.0\n"
     ]
    }
   ],
   "source": [
    "# Record author, Python/IPython versions and the versions of every library used below\n",
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p tensorflow,numpy,matplotlib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Zoo -- Logistic Regression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
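    "Implementation of *classic* logistic regression for binary class labels, trained with minibatch gradient descent on the first 100 samples (two classes, first two features) of the Iris dataset."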
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAa4AAACqCAYAAAD1E6s4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAFexJREFUeJzt3X+MVNd1B/DvYbKOVkq0hHjlyMsuUMddCxkqxMom4g8r\nxhWOCzGxGxTcWqG1hCI5ShMiWpAtZCNXpkIijVX3D1QsWtkhosLeONgVsaGKVStQ75oEsAmRHdew\nm0jBtSCRuhLL7ukfb4edmX1v5r1598398b4fCa3nMTtzWL+75829550rqgoiIiJfzLMdABERURZM\nXERE5BUmLiIi8goTFxEReYWJi4iIvMLERUREXmHiIiIirzBxERGRV5i4iIjIK5+w8aY33nijLl68\n2MZbExkxOjr6kar22o6jimOKQpB2XFlJXIsXL8bIyIiNtyYyQkQ+tB1DLY4pCkHaccWpQiIi8goT\nFxEReYWJi4iIvMLEFbLTh4Dv3Q48MT/6evqQ7YiIisVzvhSsFGdQB5w+BPz4W8DkRPT4ysXoMQAs\n32gvLqKi8JwvDX7iCtWxXbMDuGpyIjpOFCKe86XBxBWqK2PZjhP5jud8aTBxhapnYbbjRL7jOV8a\nTFyhWrMT6OquP9bVHR0nChHP+dJg4grV8o3A+meAnn4AEn1d/wwXqckdpisAec6XBqsKQ7Z8Iwct\nuamoCkCe86XAT1xE1HmsAKQcmLiIqPNYAUg5MHERUeexApByYOIios5jBSDlwMRFRJ3HCkDKgVWF\nRI4RkX4A/wbgJgAKYJ+qft9uVAVgBSC1iYmLyD3XAHxXVd8WkU8DGBWR11T1XduBEbmAU4VEjlHV\n36rq2zP//QcA5wD02Y2KyB1MXEQOE5HFAFYAOBnzd1tEZERERi5dutTp0IisYeIicpSIfArAYQDf\nVtXfN/69qu5T1SFVHert7e18gESWMHEROUhEuhAlrRdU9UXb8RC5hImLyDEiIgD2Azinqnttx0Pk\nGiauMjLdlZtMWw3gYQB3i8jPZ/7cZzsoIlfkLocvzT0noSiqKzcZo6r/BUBsx0HkKhOfuKr3nCwF\nsArAoyKy1MDrUhHYlZuIPJc7cfGeE8+wKzcRec5o54xW95wA2AIAAwMDJt+WsuhZGE0Pxh0nMunI\nVmD0AKBTgFSAlZuBdaw1ofyMFWfwnhPDjmwFnlwAPNETfT2y1czrsis3dcKRrcDI/ihpAdHXkf3m\nzmMqNSOJi/ecGFbkoGdXbuqE0QPZjhNlYKKqkPecmNZs0JuYamFXbipa9aIr7XGiDEx84uI9J6Zx\n0JPvpJLtOFEGuT9x8Z6TAkglPklx0JMvVm6OprfjjhPlxM4ZppnoSpE0uLMOenbIIFvW7QWGHpm9\n2JJK9NjEVDfP69LjRpImmepKUR3ceUqJ2SGDbFu313z5O89rAhOXWc26UmQdVHkHvclYiFzB85rA\nqUKzXOpK4VIsRKbwvCYwcZmV1H3CRlcKl2IhMoXnNYGJy6w1O4F5XfXH5nU170pR1EIzO2RQiHhe\nE7jGZZ5I88e1ilxorn7/sV3RNErPwmhwcx2AfMbzmsDEZdaxXcDU1fpjU1eTF46LXmhmhwwKEc/r\n0uNUoUlZF4650ExElBkTl0lZF4650ExElFnYiavIO+zjXjvrwjEXmimBiDwnIr8TkbO2YyFyTbiJ\nq1r4cOUiAJ0tfDCRvJJeG8i2ZQi3GKFkBwDcazsIIheFW5xRZOFDs9f+ztlsr8+FZoqhqm/M7Cju\nr9OHOl/9Z+M9qePCTVxFFj6wqIIcICJbAGwBgIGBAcvRNLDRU5B9DEsj3KnCIgsfWFRBDlDVfao6\npKpDvb29tsOp12xWIqT3JCvCTVxFFj6s2QnMa9gba14lOp5UEMKtGKhMbMxKcCakNMKdKizyDvsL\nJ4Dpho0ep6eAU88DY/89d6riwgngFz
/gFAaVR8/CmeKlmOMhvSdZEe4nLiBKCt85CzxxOXvRRDOj\nB+KPf/DT+KmK0QOcwqBMROQggJ8BGBSRMRF5xHZMmdi41YO3l5RGuJ+4iqRTrZ+T5vmcwqAEqrrJ\ndgy5LN8YzTTUbob6Jw+ZuXhMqhxkH8PSYOJqh1SyJa+k53MKg0J1+lA0PV4973UqejywKl8iaVU5\nyNtLSiHsqcKsBRFHtgJPLgCe6Im+Htka/7yVm+OPL7krfqpi5WZ3tjtpw/CpcazefRxLtr+C1buP\nY/jUuLVYyBNFVfixcpAQcuLK2jnjyFZgZH/9FeLI/uTkFeezn4/vhDGwqr3tToro+pHR8Klx7Hjx\nDMYvT0ABjF+ewI4XzzB5UXNFVfixcpAQcuLKemWWVHARd7zZc+MKQpptd2Ii9gLtOXoeE5P105wT\nk1PYc/R8x2MhjxR1ryPvoSSEnLiyXpklrVnFHc/y3HZiceiq8jeXJzIdJwJQXIUfKwcJISeurFdm\nUkl/PMtz24nFoavKm+d3ZzpOBKC4BtJsTA2A687hJq6sV2ZJBRdxx5s9N7DtTratHUR3V31C7u6q\nYNvawY7HQp4p6j7Kol7XE1x3DjlxZb0yW7cXGHpk9lOTVKLH6/bOfe7AqrmfrqqPA9vuZMOKPjz9\nwDL0ze+GAOib342nH1iGDSv6Oh4LEXHdGQBEVTv+pkNDQzoyMtLx9zXme7fHt5ZJvF+rP7oypGCI\nyKiqDtmOo8r7MUWpLdn+CuJ+awuAD3b/WafDMSrtuAr3E1eRshZ4sFSXiAzhujMTV3uyFniwVJeI\nDOG6s28tn7Lubpr0/Ly7pK7ZCfzo0fp7syo3ACseru8CD3hVqjt8ahx7jp7Hby5P4Ob53di2drB0\na1lB/wy4O3AQqudj0nka9Dk8w5/ElXV306Tnm9pipHFtUDUq2hhY5eUvh2qlUnXRt1qpBCC4kz5J\n0D8D7g4clA0r+mLPyaDP4Rr+TBVm7SaR9HwTW4wc2wVMT9Yfm56MjntaqstKpcB/Bg51Y6HiBH0O\n1/AncZnqMmGigMKhzhamsENG4D+DAM9Zmivoc7iGP4nLVJcJEwUUDnW2MIWVSoH/DAI8Z2muoM/h\nGv4kLlPdJ1Zuzt+VwqHOFqawUinwn0GA5yzNFfQ5XMNIcYaI3Avg+wAqAP5FVXebeN06WXc3TdqB\ndd3e5AKKpKqrf/0y8MFPZ197yV1RJwsPizCStKpUcp2JSirffwZNcXfgUijqHM4zvoqocszdOUNE\nKgB+BeBPAYwBeAvAJlV9N+l7OnKXf2MVFRBdYSa1Tkp6fs8i4KNfzn3+kruAr79sPm7KrLGSCoiu\nMotsTVV054ysF4PsnEFFyTO+sn5vJztn3AHgPVX9tapeBfBDAPcbeN18TFUhxiUtoP4TGFkVWiXV\nzMXgswC+BGApgE0istRuVFRWecZXUWPTROLqA1DbuG9s5lgdEdkiIiMiMnLp0iUDb9uCB3tdkRkB\nVlK5eTFIpZRnfBU1NjtWnKGq+1R1SFWHent7i39DD/a6IjMCrKRy82KQSinP+CpqbJoozhgH0F/z\neOHMsfaZaNW0Zmf8mlWzKsSXvlF/n5dUgM/eGj9deONtM13i3Vjofnz4DA6evIgpVVREsOnOfgwt\nWpBpUTTrIqqN1jJx77lt7SC2/fsvMDk9u17bNU+Cq6RqpKr7AOwDojUuy+FQoLatHYxdp0ozvvJ8\nbzMmEtdbAG4VkSWIEtbXADzU9quZatWUtYrqwom5NyfrFPDpm4CPzgONGwl8/P5s9wzL7XMeHz6D\n509cuP54ShXPn7iAH5y4gOmZY61av2RtFWOjtUzSez64si/a06FW42O/mL8YJGpTnkrFoqocjezH\nJSL3AfhHRBVQz6nq3zd7ftMKKFt7XT25ILmrRlqW9t26ZcermEr5/7Fvfjfe3H73nOOrdx/HeMy8\ns6
nnm5D0nhWR2H9/kbEUWVUoIp9AVKm7BlHCegvAQ6r6TtL3lKmqsAxNZMsq7bgych+Xqr4K4FUT\nr2Vtr6u8SQuwVuCRNmkB2RdLTR03Iem1k/79vhZnqOo1EfkmgKOYvRhMTFplUpYmstSce50zbO11\nlfT6WVgq8KhI+nmxrIulpo6bkPTaSf9+j4szoKqvquofq+otrWYwyiS0Wx+oPe4lrmatmuZ11R+f\n12WuZc3KzfHHl9w1N57KDXNjsdg+Z9Od/a2fhOYFC1lbxWxbO4iuefUJo92CiOFT41i9+ziWbH8F\nq3cfx/Cp+OWcbWsH0VVpeM9KVIhShjY3FOStD9QG9xLX8o1Rd4uefgASfV3/TNSmqfHKOsMnjZbW\n7QWGHpn95CWV6PHXX54bz/3PAhv+eW6MlqoKhxYtQKUhicyT6E+dJj+uDSv68PQDy9A3vxuCaH2o\n5Z3xBgoiqlM/45cnoJid+klKXo01MtDo3585dvJSgLc+UBuMFGdk1dZCclLRhqWCCJckFS3EMVWw\nYKo4I8vr2CgISVJ0y6esylKcYaO9F3VOR4szOoIdLxJlmSYxNaViasomy+twmoh8a4TcbgUkKyeb\n8ydx9SxM+MTFjhc3z+9O/YnL1JRK0ntmff0sr2PqPclvSdvWu6bdCkhWTrbm3hpXEu4nlCiusKKr\nInOWnCoGu0mYKs5IKgr54m29cwo22tlrKG3hB5Fp7VZAsnKyNX8SV1LRBvcTii2suGPxZ+bUMUxN\nK0Y+/NjcGxsozoiL/cGVfTg8Oj6nYANApiKMzIUfRAa1O7XNKfHW/JkqBKIkxUQVq3H65JYd8feD\nHzx5EU9tWJb7/fYcPY/JqfrUODml2HP0fFsbONZ+z+rdxxOvON/cfnfq12925copFypau1PbnBJv\nzZ9PXJRJUjeJLF02minyqtBG4QeRae1Mbef5vjLx6xMXpZbUvy9Ll41mirwqtFH4QWRaqwrIpMpB\n3yonbWDi6pCiy1sbX3/VH30Gb74/dz1r0539RrYvKWq7AsDcVghFxkiURlIFZKvKQV8qJ23hVGEH\nFF0kEPf6b1+4gtW3LLj+Casigr9cNYChRQsyxZIUO5CtUCKLtrp4FPg6RKaxcjAffuLqgKKLBJJe\n/3/+dwLvP31f3fFmhQ9xsTSLPUuhRFamrjh55Uou4vprPvzE1QFFn6RFdp/gACMyjz0X82Hi6oCi\nT9Isr+/D9iVEoWPlYD5MXB1g8iSN6wRRZPcJDjAi8x1YuP6ajz/d4T1noqqwWWdsoL589ou39eLw\n6Hiq57ZTVVj2Acbu8OXBjvSdk3ZcMXF5xNctQELExFUeHEudk3ZccarQI9wCJHwi8lUReUdEpkXE\nmcRYZhxL7mHi8kiRRRjkjLMAHgDwhu1AKMKx5B4mLsOK3EZj29pBdFUathKpxG8l4mJRBbcYaU1V\nz6kq70J1iItjqeyYuAzqyDYajUuSCUuUrlUtcYsR80Rki4iMiMjIpUuXbIcTrA0r+vDgyr66LjQP\nruSN7Taxc4ZBneiQMTndsJXIdPJWIi51jeAWI7NE5HUAn4v5q8dU9UdpX0dV9wHYB0TFGYbCowbD\np8ZxeHT8etPqKVUcHh3H0KIFpTt3XcHEZZBLHTJc43PspqnqPbZjoPR40eUeThUa5FKHDNf4HDuV\nGy+63MPElULaooKiF3F9XiT2OfZOEpGviMgYgC8AeEVEjtqOqex40eUeJq4WshQVFF0Q4VrBRRY+\nx95JqvqSqi5U1U+q6k2qutZ2TGXHiy73sHNGC7xrnuKwc4Z/8rQuY9uzzkg7rlic0QLnt4n812rH\n4VZcqtAlThW2xPltIv9xx+GwhPGJ6/Qh4Ngu4MoY0LMQWLMTWL7RyEtvWzsY2xna1vy2z1MWPsdO\nfuPMSVj8T1ynDwE//hYwOXMCXrkYPQaMJK/qL1YXfuHmne6wyefY
yX83z++OXavmzImf/E9cx3bN\nJq2qyYnouKFPXa7Mb/t8I6TPsZP/XJs5oXz8T1xXxrId95jP0x0+x07+eHz4DA6evIgpVVREsOnO\nfjy1YZlTMyeUn/+Jq2dhND0YdzwwPk93+Bw7+eHx4TN4/sSF64+nVK8/riYvJqow5KoqFJE9IvJL\nETktIi+JyHxTgaW2ZifQ1fDLr6s7Oh4YX26EjOs04kvs5K+DJ2MuYJscJ3/lLYd/DcDtqrocwK8A\n7MgfUkbLNwLrnwF6+gFI9HX9M8bWt1ziQ/eJpE4jAJyPnfw2ldBMIek4+SvXVKGq/qTm4QkAf54v\nnDYt3xhkoorj+nRHsyKMN7ff7XTs5LeKSGySqu6jReEweQPyXwP4j6S/5KZ35cAiDLJl0539mY6T\nv1p+4kqz6Z2IPAbgGoAXkl6Hm96VA4swysWlm8qf2rAMAGKrCiksLRNXq03vRGQzgHUA1qiNjr3k\nFN4vUx4u3lT+1IZlTFQlkLeq8F4Afwvgy6r6f2ZCIp/5UEBCZrD/H9mS9z6ufwLwSQCvSbQAekJV\nv5E7KvKa6wUkZAbXM8mWvFWFnzcVCBH5heuZZAu3NSGitvCmcrLF/5ZPlrhUTUXhEJE9ANYDuArg\nfQB/paqX7UYVj/3/WuPviWIwcbXBxWoqCsZrAHao6jUR+QdE3Wj+znJMibiemYy/J4rDqcI2sJqK\niqKqP1HVazMPTwAIr1t0SfD3RHGYuNrAairqEHaj8Rh/TxSHiasNSVVTrKaiNETkdRE5G/Pn/prn\npOpGo6pDqjrU29vbidApA/6eKA4TVxtYTUV5qOo9qnp7zJ9qC7XNiLrR/AW70fiLvyeKw+KMNrCa\niopS043mLnaj8Rt/TxSHiatNrKaigrAbTUD4e6IYTFxEDmE3GqLWuMZFREReERtrvyJyCcCHHX/j\n5m4E8JHtIBK4GluZ41qkqs6U8uUcU67+f0yDsdtRVOypxpWVxOUiERlR1SHbccRxNTbGFQaff16M\n3Q7bsXOqkIiIvMLERUREXmHimrXPdgBNuBob4wqDzz8vxm6H1di5xkVERF7hJy4iIvIKExcREXmF\niauGiHxVRN4RkWkRsV6mKiL3ish5EXlPRLbbjqdKRJ4Tkd+JyFnbsdQSkX4R+U8ReXfm/+Pf2I7J\nB66d92m4OjbScHX8tOLS+GLiqncWwAMA3rAdiIhUADwL4EsAlgLYJCJL7UZ13QEA99oOIsY1AN9V\n1aUAVgF41KGfmcucOe/TcHxspHEAbo6fVpwZX0xcNVT1nKq6sj3pHQDeU9Vfq+pVAD8EcH+L7+kI\nVX0DwMe242ikqr9V1bdn/vsPAM4BYIfTFhw779Nwdmyk4er4acWl8cXE5a4+ABdrHo+Bv4RTE5HF\nAFYAOGk3EioAx4ZltsdX6brDi8jrAD4X81ePVTfyI7+JyKcAHAbwbVX9ve14XMDznkxxYXyVLnGp\n6j22Y0hpHEB/zeOFM8eoCRHpQjSoXlDVF23H4wqPzvs0ODYscWV8carQXW8BuFVElojIDQC+BuBl\nyzE5TaKdF/cDOKeqe23HQ4Xh2LDApfHFxFVDRL4iImMAvgDgFRE5aisWVb0G4JsAjiJaBD2kqu/Y\niqeWiBwE8DMAgyIyJiKP2I5pxmoADwO4W0R+PvPnPttBuc6l8z4Nl8dGGg6Pn1acGV9s+URERF7h\nJy4iIvIKExcREXmFiYuIiLzCxEVERF5h4iIiIq8wcRERkVeYuIiIyCv/DzTREOtIuGjEAAAAAElF\nTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x10f6e0cc0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from io import BytesIO\n",
    "\n",
    "##########################\n",
    "### DATASET\n",
    "##########################\n",
    "\n",
    "# Binary subset: first 100 rows of iris.data, first two feature columns.\n",
    "n_rows = 100\n",
    "n_test = 25\n",
    "\n",
    "ds = np.lib.DataSource()\n",
    "# Context manager so the opened file handle is closed after reading.\n",
    "with ds.open('https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data') as fp:\n",
    "    iris_text = fp.read()\n",
    "\n",
    "x = np.genfromtxt(BytesIO(iris_text.encode()), delimiter=',', usecols=range(2), max_rows=n_rows)\n",
    "\n",
    "# Derive labels from the class column (last field) instead of assuming the\n",
    "# file is sorted by class; fail loudly if the subset is not the expected pair.\n",
    "iris_classes = np.array([line.split(',')[-1] for line in iris_text.splitlines()[:n_rows]])\n",
    "assert set(iris_classes) == {'Iris-setosa', 'Iris-versicolor'}, 'unexpected classes in iris.data subset'\n",
    "y = (iris_classes == 'Iris-versicolor').astype(float)\n",
    "assert x.shape == (n_rows, 2) and y.shape == (n_rows,)\n",
    "\n",
    "np.random.seed(1)\n",
    "idx = np.arange(y.shape[0])\n",
    "np.random.shuffle(idx)\n",
    "x_test, y_test = x[idx[:n_test]], y[idx[:n_test]]\n",
    "x_train, y_train = x[idx[n_test:]], y[idx[n_test:]]\n",
    "# Standardize with training-set statistics only (no test-set leakage).\n",
    "mu, std = np.mean(x_train, axis=0), np.std(x_train, axis=0)\n",
    "x_train, x_test = (x_train - mu) / std, (x_test - mu) / std\n",
    "\n",
    "fig, ax = plt.subplots(1, 2, figsize=(7, 2.5))\n",
    "splits = [(x_train, y_train, 'Training set'), (x_test, y_test, 'Test set')]\n",
    "for axis, (x_split, y_split, title) in zip(ax, splits):\n",
    "    axis.scatter(x_split[y_split == 1, 0], x_split[y_split == 1, 1], label='versicolor (1)')\n",
    "    axis.scatter(x_split[y_split == 0, 0], x_split[y_split == 0, 1], label='setosa (0)')\n",
    "    axis.set_title(title)\n",
    "    axis.set_xlabel('feature 1 (standardized)')\n",
    "ax[0].set_ylabel('feature 2 (standardized)')\n",
    "ax[0].legend()\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "##########################\n",
    "### HELPER FUNCTIONS\n",
    "##########################\n",
    "\n",
    "def iterate_minibatches(arrays, batch_size, shuffle=False, seed=None, drop_last=True):\n",
    "    \"\"\"Yield minibatches of row-aligned slices from several arrays.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    arrays : sequence of numpy arrays\n",
    "        Arrays to slice in lockstep (e.g. ``[x_train, y_train]``). The row\n",
    "        count is taken from ``arrays[0]``; the other arrays are assumed to\n",
    "        have at least as many rows.\n",
    "    batch_size : int\n",
    "        Number of rows per minibatch.\n",
    "    shuffle : bool, default False\n",
    "        If True, rows are visited in a random permutation.\n",
    "    seed : int or None, default None\n",
    "        Seed for the ``np.random.RandomState`` used when shuffling. A fresh\n",
    "        generator is created on every call, so a fixed seed reproduces the\n",
    "        same permutation each time.\n",
    "    drop_last : bool, default True\n",
    "        If True, a trailing partial minibatch (fewer than ``batch_size``\n",
    "        rows) is skipped; if False, it is yielded as a smaller final batch.\n",
    "\n",
    "    Yields\n",
    "    ------\n",
    "    tuple of numpy arrays\n",
    "        One slice per input array, in the same order as ``arrays``.\n",
    "    \"\"\"\n",
    "    rgen = np.random.RandomState(seed)\n",
    "    n_rows = arrays[0].shape[0]\n",
    "    indices = np.arange(n_rows)\n",
    "\n",
    "    if shuffle:\n",
    "        rgen.shuffle(indices)\n",
    "\n",
    "    # With drop_last, stop early enough that every slice has batch_size rows.\n",
    "    stop = n_rows - batch_size + 1 if drop_last else n_rows\n",
    "\n",
    "    for start_idx in range(0, stop, batch_size):\n",
    "        index_slice = indices[start_idx:start_idx + batch_size]\n",
    "        # A tuple (rather than a generator) can be unpacked, indexed and reused.\n",
    "        yield tuple(ary[index_slice] for ary in arrays)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000 | AvgCost: nan | Train/Valid ACC: 0.53/0.40\n",
      "Epoch: 001 | AvgCost: 4.221 | Train/Valid ACC: 1.00/1.00\n",
      "Epoch: 002 | AvgCost: 1.225 | Train/Valid ACC: 1.00/1.00\n",
      "Epoch: 003 | AvgCost: 0.610 | Train/Valid ACC: 1.00/1.00\n",
      "Epoch: 004 | AvgCost: 0.376 | Train/Valid ACC: 1.00/1.00\n",
      "Epoch: 005 | AvgCost: 0.259 | Train/Valid ACC: 1.00/1.00\n",
      "Epoch: 006 | AvgCost: 0.191 | Train/Valid ACC: 1.00/1.00\n",
      "Epoch: 007 | AvgCost: 0.148 | Train/Valid ACC: 1.00/1.00\n",
      "Epoch: 008 | AvgCost: 0.119 | Train/Valid ACC: 1.00/1.00\n",
      "Epoch: 009 | AvgCost: 0.098 | Train/Valid ACC: 1.00/1.00\n",
      "Epoch: 010 | AvgCost: 0.082 | Train/Valid ACC: 1.00/1.00\n",
      "Epoch: 011 | AvgCost: 0.070 | Train/Valid ACC: 1.00/1.00\n",
      "Epoch: 012 | AvgCost: 0.061 | Train/Valid ACC: 1.00/1.00\n",
      "Epoch: 013 | AvgCost: 0.053 | Train/Valid ACC: 1.00/1.00\n",
      "Epoch: 014 | AvgCost: 0.047 | Train/Valid ACC: 1.00/1.00\n",
      "\n",
      "Weights:\n",
      " [[ 3.31176686]\n",
      " [-2.40808702]]\n",
      "\n",
      "Bias:\n",
      " [[-0.01001291]]\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "n_features = x.shape[1]\n",
    "n_samples = x.shape[0]\n",
    "learning_rate = 0.05\n",
    "training_epochs = 15\n",
    "batch_size = 10\n",
    "random_seed = 123\n",
    "\n",
    "\n",
    "##########################\n",
    "### GRAPH DEFINITION\n",
    "##########################\n",
    "\n",
    "g = tf.Graph()\n",
    "with g.as_default() as g:\n",
    "\n",
    "    # Input data\n",
    "    tf_x = tf.placeholder(dtype=tf.float32,\n",
    "                          shape=[None, n_features], name='inputs')\n",
    "    tf_y = tf.placeholder(dtype=tf.float32,\n",
    "                          shape=[None], name='targets')\n",
    "\n",
    "    # Model parameters\n",
    "    params = {\n",
    "        'weights': tf.Variable(tf.zeros(shape=[n_features, 1],\n",
    "                                        dtype=tf.float32), name='weights'),\n",
    "        'bias': tf.Variable([[0.]], dtype=tf.float32, name='bias')}\n",
    "\n",
    "    # Logistic Regression\n",
    "    linear = tf.matmul(tf_x, params['weights']) + params['bias']\n",
    "    pred_proba = tf.sigmoid(linear, name='predict_probas')\n",
    "\n",
    "    # Loss and optimizer\n",
    "    # Binary cross-entropy *summed* over the minibatch, so the effective step\n",
    "    # size scales with batch_size. (The former outer tf.reduce_mean wrapped an\n",
    "    # already-scalar sum and was a no-op.)\n",
    "    # NOTE: tf.log(r) / tf.log(1 - r) become -inf if the sigmoid saturates to\n",
    "    # exactly 0 or 1; tf.nn.sigmoid_cross_entropy_with_logits is the\n",
    "    # numerically stable alternative.\n",
    "    r = tf.reshape(pred_proba, [-1])\n",
    "    cost = tf.reduce_sum((-tf_y * tf.log(r)) -\n",
    "                         ((1. - tf_y) * tf.log(1. - r)), name='cost')\n",
    "    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)\n",
    "    train = optimizer.minimize(cost, name='train')\n",
    "\n",
    "    # Class prediction: round the probabilities to 0/1 labels\n",
    "    pred_labels = tf.round(r, name='predict_labels')\n",
    "    correct_prediction = tf.equal(tf_y, pred_labels)\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')\n",
    "\n",
    "\n",
    "##########################\n",
    "### TRAINING & EVALUATION\n",
    "##########################\n",
    "\n",
    "with tf.Session(graph=g) as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "\n",
    "    for epoch in range(training_epochs):\n",
    "\n",
    "        # Reset the running totals every epoch so AvgCost is the mean\n",
    "        # minibatch cost of this epoch only.\n",
    "        epoch_cost = 0.\n",
    "        n_batches = 0\n",
    "        # A different (but reproducible) seed per epoch gives a new shuffle\n",
    "        # each epoch instead of repeating the same minibatch order.\n",
    "        for x_batch, y_batch in iterate_minibatches(arrays=[x_train, y_train],\n",
    "                                                    batch_size=batch_size,\n",
    "                                                    shuffle=True,\n",
    "                                                    seed=random_seed + epoch):\n",
    "\n",
    "            feed_dict = {'inputs:0': x_batch,\n",
    "                         'targets:0': y_batch}\n",
    "            _, c = sess.run(['train', 'cost:0'], feed_dict=feed_dict)\n",
    "\n",
    "            epoch_cost += c\n",
    "            n_batches += 1\n",
    "\n",
    "        # Evaluate after this epoch's updates so cost and accuracies\n",
    "        # printed on one line refer to the same epoch.\n",
    "        train_acc = sess.run('accuracy:0', feed_dict={tf_x: x_train,\n",
    "                                                      tf_y: y_train})\n",
    "        valid_acc = sess.run('accuracy:0', feed_dict={tf_x: x_test,\n",
    "                                                      tf_y: y_test})\n",
    "\n",
    "        avg_cost = epoch_cost / n_batches if n_batches else np.nan\n",
    "        print(\"Epoch: %03d | AvgCost: %.3f\" % (epoch, avg_cost), end=\"\")\n",
    "        print(\" | Train/Valid ACC: %.2f/%.2f\" % (train_acc, valid_acc))\n",
    "\n",
    "    weights, bias = sess.run(['weights:0', 'bias:0'])\n",
    "    print('\\nWeights:\\n', weights)\n",
    "    print('\\nBias:\\n', bias)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
